{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1e8ed773-0da3-49f2-8481-2ac87a12926f",
   "metadata": {
    "id": "1e8ed773-0da3-49f2-8481-2ac87a12926f"
   },
   "source": [
    "# 基于 SQL 和 Jina Reranker v2 的 RAG\n",
    "\n",
    "_作者：[Scott Martens](https://github.com/scott-martens) @ [Jina AI](https://jina.ai)_\n",
    "\n",
    "本教程将展示如何构建一个简单的检索增强生成（RAG）系统，该系统从 SQL 数据库中提取信息，而不是从文档存储中提取。\n",
    "\n",
    "### 工作原理\n",
    "\n",
    "* 给定一个 SQL 数据库，我们提取 SQL 表的定义（SQL 导出文件中的 `CREATE` 语句），并将其存储。在本教程中，我们已经为您完成了这部分操作，表定义被存储在内存中，作为一个列表。根据此示例扩展可能需要更复杂的存储方案。\n",
    "* 用户输入一个自然语言查询。\n",
    "* [Jina Reranker v2](https://jina.ai/reranker/)（[`jinaai/jina-reranker-v2-base-multilingual`](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual)），一个由 [Jina AI](https://jina.ai) 提供的 SQL 感知排序模型，会根据查询的相关性对表定义进行排序。\n",
    "* 我们将用户的查询和排名前三的表定义作为提示，传递给 [Mistral 7B Instruct v0.1 \\(`mistralai/Mistral-7B-Instruct-v0.1`)](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)，并请求生成一个 SQL 查询来完成任务。\n",
    "* Mistral Instruct 生成一个 SQL 查询，我们将其在数据库上执行并检索结果。\n",
    "* SQL 查询结果被转换为 JSON 格式，并作为新提示传递给 Mistral Instruct，包含用户的原始查询、SQL 查询及请求，要求生成自然语言形式的答案。\n",
    "* Mistral Instruct 的自然语言文本响应返回给用户。\n",
    "\n",
    "### 数据库\n",
    "\n",
    "本教程使用一个小型的开放访问视频游戏销售记录数据库，存储在 [GitHub](https://github.com/bbrumm/databasestar/tree/main/sample_databases/sample_db_videogames/sqlite) 上。我们将使用 [SQLite](https://www.sqlite.org/index.html) 版本，因为 SQLite 非常紧凑，跨平台，并且内置对 Python 的支持。\n",
    "\n",
    "### 软件和硬件要求\n",
    "\n",
    "我们将在本地运行 Jina Reranker v2 模型。如果您使用 Google Colab 运行此笔记本，请确保使用支持 GPU 的运行时。如果您在本地运行，您需要 Python 3（本教程使用 Python 3.11 编写），并且在启用了 CUDA 的 GPU 上运行将会 *大大* 提升速度。\n",
    "\n",
    "本教程还将广泛使用开源的 [LlamaIndex RAG 框架](https://www.llamaindex.ai/)，以及 [Hugging Face Inference API](https://huggingface.co/inference-api/serverless) 来访问 Mistral 7B Instruct v0.1。您需要一个 [Hugging Face 账户](https://huggingface.co/login) 和一个至少具有 `READ` 权限的 [访问令牌](https://huggingface.co/settings/tokens)。\n",
    "\n",
    "> [!WARNING]\n",
    "> 如果你使用 Google Colab，SQLite 已经安装。它可能没有安装在您的本地计算机上。如果未安装，请按照 [SQLite 网站](https://www.sqlite.org/download.html) 上的说明进行安装。Python 接口代码已经集成在 Python 中，无需额外安装任何 Python 模块。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "G18FHmeE5_5Q",
   "metadata": {
    "id": "G18FHmeE5_5Q"
   },
   "source": [
    "## 开始"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c22zH-xVFEvV",
   "metadata": {
    "id": "c22zH-xVFEvV"
   },
   "source": [
    "### 安装环境\n",
    "\n",
    "首先，安装需要的 python 模块："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "674424a4-8371-4010-a4b2-650d427a556d",
   "metadata": {
    "id": "674424a4-8371-4010-a4b2-650d427a556d"
   },
   "outputs": [],
   "source": [
    "!pip install -qU transformers einops llama-index llama-index-postprocessor-jinaai-rerank  llama-index-llms-huggingface \"huggingface_hub[inference]\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "EzhBY8pq_av9",
   "metadata": {
    "id": "EzhBY8pq_av9"
   },
   "source": [
    "### 下载数据库\n",
    "\n",
    "接下来，从 [GitHub](https://github.com/bbrumm/databasestar/tree/main/sample_databases/sample_db_videogames/sqlite) 下载 SQLite 数据库 `videogames.db` 到本地文件系统。如果你的系统上没有 `wget` 命令，可以通过 [这个链接](https://github.com/bbrumm/databasestar/raw/main/sample_databases/sample_db_videogames/sqlite/videogames.db) 下载数据库，并将其放置在你运行本 Notebook 的相同目录中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec28e417-a4cb-4fab-a614-4867b84ec7f4",
   "metadata": {
    "id": "ec28e417-a4cb-4fab-a614-4867b84ec7f4"
   },
   "outputs": [],
   "source": [
    "!wget https://github.com/bbrumm/databasestar/raw/main/sample_databases/sample_db_videogames/sqlite/videogames.db"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3GY3Q13NINd0",
   "metadata": {
    "id": "3GY3Q13NINd0"
   },
   "source": [
    "### 下载并运行 Jina Reranker v2\n",
    "\n",
    "以下代码将下载模型 `jina-reranker-v2-base-multilingual` 并在本地运行： \n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "KB8xGaNXIjAC",
   "metadata": {
    "id": "KB8xGaNXIjAC"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "reranker_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    'jinaai/jina-reranker-v2-base-multilingual',\n",
    "    torch_dtype=\"auto\",\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "\n",
    "reranker_model.to('cuda') # or 'cpu' if no GPU is available\n",
    "reranker_model.eval()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "SGgk4EmkOcy4",
   "metadata": {
    "id": "SGgk4EmkOcy4"
   },
   "source": [
    "### 设置 Mistral Instruct 的接口\n",
    "\n",
    "我们将使用 LlamaIndex 创建一个持有对象，用于连接 Hugging Face 推理 API 和运行在那里的 `mistralai/Mistral-7B-Instruct-v0.1` 模型。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee7bedf5-4730-4ac1-8a93-7f0135779d9b",
   "metadata": {
    "id": "ee7bedf5-4730-4ac1-8a93-7f0135779d9b"
   },
   "source": [
    "首先，从你的 [Hugging Face 账户设置页面](https://huggingface.co/settings/tokens) 获取一个 Hugging Face 访问令牌。\n",
    "\n",
    "在下面的提示中输入该令牌："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54122f11-e0c3-4e89-bf54-f0eece57cf71",
   "metadata": {
    "id": "54122f11-e0c3-4e89-bf54-f0eece57cf71"
   },
   "outputs": [],
   "source": [
    "import getpass\n",
    "\n",
    "print(\"Paste your Hugging Face access token here: \")\n",
    "hf_token = getpass.getpass()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4HpY0xohPCWW",
   "metadata": {
    "id": "4HpY0xohPCWW"
   },
   "source": [
    "接下来，初始化 LlamaIndex 中 `HuggingFaceInferenceAPI` 类的实例，并将其存储为 `mistral_llm`："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4a2ea12-d5d8-469b-afac-af9aed5b2240",
   "metadata": {
    "id": "e4a2ea12-d5d8-469b-afac-af9aed5b2240"
   },
   "outputs": [],
   "source": [
    "from llama_index.llms.huggingface import HuggingFaceInferenceAPI\n",
    "\n",
    "mistral_llm = HuggingFaceInferenceAPI(\n",
    "    model_name=\"mistralai/Mixtral-8x7B-Instruct-v0.1\", token=hf_token\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9257efd9-806a-4ada-871a-406b397fe127",
   "metadata": {
    "id": "9257efd9-806a-4ada-871a-406b397fe127"
   },
   "source": [
    "## 使用 SQL 感知的 Jina Reranker v2\n",
    "\n",
    "我们从 [GitHub 上的数据库导入文件](https://github.com/bbrumm/databasestar/tree/main/sample_databases/sample_db_videogames/sqlite) 中提取了八个表的定义。运行以下命令，将它们放入名为 `table_declarations` 的 Python 列表中："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74677bb6-cec4-4ddc-8379-d13fb3366fb1",
   "metadata": {
    "id": "74677bb6-cec4-4ddc-8379-d13fb3366fb1"
   },
   "outputs": [],
   "source": [
    "table_declarations = ['CREATE TABLE platform (\\n\\tid INTEGER PRIMARY KEY,\\n\\tplatform_name TEXT DEFAULT NULL\\n);',\n",
    " 'CREATE TABLE genre (\\n\\tid INTEGER PRIMARY KEY,\\n\\tgenre_name TEXT DEFAULT NULL\\n);',\n",
    " 'CREATE TABLE publisher (\\n\\tid INTEGER PRIMARY KEY,\\n\\tpublisher_name TEXT DEFAULT NULL\\n);',\n",
    " 'CREATE TABLE region (\\n\\tid INTEGER PRIMARY KEY,\\n\\tregion_name TEXT DEFAULT NULL\\n);',\n",
    " 'CREATE TABLE game (\\n\\tid INTEGER PRIMARY KEY,\\n\\tgenre_id INTEGER,\\n\\tgame_name TEXT DEFAULT NULL,\\n\\tCONSTRAINT fk_gm_gen FOREIGN KEY (genre_id) REFERENCES genre(id)\\n);',\n",
    " 'CREATE TABLE game_publisher (\\n\\tid INTEGER PRIMARY KEY,\\n\\tgame_id INTEGER DEFAULT NULL,\\n\\tpublisher_id INTEGER DEFAULT NULL,\\n\\tCONSTRAINT fk_gpu_gam FOREIGN KEY (game_id) REFERENCES game(id),\\n\\tCONSTRAINT fk_gpu_pub FOREIGN KEY (publisher_id) REFERENCES publisher(id)\\n);',\n",
    " 'CREATE TABLE game_platform (\\n\\tid INTEGER PRIMARY KEY,\\n\\tgame_publisher_id INTEGER DEFAULT NULL,\\n\\tplatform_id INTEGER DEFAULT NULL,\\n\\trelease_year INTEGER DEFAULT NULL,\\n\\tCONSTRAINT fk_gpl_gp FOREIGN KEY (game_publisher_id) REFERENCES game_publisher(id),\\n\\tCONSTRAINT fk_gpl_pla FOREIGN KEY (platform_id) REFERENCES platform(id)\\n);',\n",
    " 'CREATE TABLE region_sales (\\n\\tregion_id INTEGER DEFAULT NULL,\\n\\tgame_platform_id INTEGER DEFAULT NULL,\\n\\tnum_sales REAL,\\n   CONSTRAINT fk_rs_gp FOREIGN KEY (game_platform_id) REFERENCES game_platform(id),\\n\\tCONSTRAINT fk_rs_reg FOREIGN KEY (region_id) REFERENCES region(id)\\n);']\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "I8ZRU3bxQZUf",
   "metadata": {
    "id": "I8ZRU3bxQZUf"
   },
   "source": [
    "现在，我们定义一个函数，该函数接受一个自然语言查询和表定义列表，使用 Jina Reranker v2 对所有表进行评分，并按得分从高到低返回它们："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kJ_JwS-vQkj9",
   "metadata": {
    "id": "kJ_JwS-vQkj9"
   },
   "outputs": [],
   "source": [
    "from typing import List, Tuple\n",
    "\n",
    "def rank_tables(query: str, table_specs: List[str], top_n:int=0) -> List[Tuple[float, str]]:\n",
    "  \"\"\"\n",
    "  Get sorted pairs of scores and table specifications, then return the top N,\n",
    "  or all if top_n is 0 or default.\n",
    "  \"\"\"\n",
    "  pairs = [[query, table_spec] for table_spec in table_specs]\n",
    "  scores = reranker_model.compute_score(pairs)\n",
    "  scored_tables = [(score, table_spec) for score, table_spec in zip(scores, table_specs)]\n",
    "  scored_tables.sort(key=lambda x: x[0], reverse=True)\n",
    "  if top_n and top_n < len(scored_tables):\n",
    "    return scored_tables[0:top_n]\n",
    "  return scored_tables"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "vk1YYqzesB-f",
   "metadata": {
    "id": "vk1YYqzesB-f"
   },
   "source": [
    "Jina Reranker v2 会对我们提供的每个表定义进行评分，默认情况下，这个函数将返回所有表及其得分。可选参数 `top_n` 限制返回的结果数量，按得分从高到低，直到用户定义的数量。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "zD1rdBiYQa_H",
   "metadata": {
    "id": "zD1rdBiYQa_H"
   },
   "source": [
    "试试这个。首先，定义一个查询："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4859647d-2bc2-4238-a5a1-235bed4c4f55",
   "metadata": {
    "id": "4859647d-2bc2-4238-a5a1-235bed4c4f55"
   },
   "outputs": [],
   "source": [
    "user_query = \"Identify the top 10 platforms by total sales.\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "H0O6qqMGRfu2",
   "metadata": {
    "id": "H0O6qqMGRfu2"
   },
   "source": [
    "运行 `rank_tables` 来获取表定义的列表。我们将 `top_n` 设置为 3，以限制返回列表的大小，并将结果赋值给变量 `ranked_tables`，然后检查结果："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0OU98s_aRWqM",
   "metadata": {
    "id": "0OU98s_aRWqM"
   },
   "outputs": [],
   "source": [
    "ranked_tables = rank_tables(user_query, table_declarations, top_n=3)\n",
    "ranked_tables"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1Oh7cX_WT-bc",
   "metadata": {
    "id": "1Oh7cX_WT-bc"
   },
   "source": [
    "输出应该包括 `region_sales`、`platform` 和 `game_platform` 这三个表，它们似乎都是查找查询答案的合理地方。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adf28cd8-974a-457d-90b6-8719d8d860fc",
   "metadata": {
    "id": "adf28cd8-974a-457d-90b6-8719d8d860fc"
   },
   "source": [
    "## 使用 Mistral Instruct 生成 SQL 查询\n",
    "\n",
    "我们将使用 Mistral Instruct v0.1 编写一个 SQL 查询，满足用户的查询需求，基于根据重新排序器得出的前三个表的声明。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df5d5f53-9cac-42ad-8ac8-627882e1046c",
   "metadata": {
    "id": "df5d5f53-9cac-42ad-8ac8-627882e1046c"
   },
   "source": [
    "首先，我们使用 LlamaIndex 的 `PromptTemplate` 类为此目的创建一个提示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "660c810a-96ff-4726-9bf1-8084ada9a80e",
   "metadata": {
    "id": "660c810a-96ff-4726-9bf1-8084ada9a80e"
   },
   "outputs": [],
   "source": [
    "from llama_index.core import PromptTemplate\n",
    "\n",
    "make_sql_prompt_tmpl_text = (\n",
    "    \"\"\"\n",
    "Generate a SQL query to answer the following question from the user:\n",
    "\\\"{query_str}\\\"\n",
    "\n",
    "The SQL query should use only tables with the following SQL definitions:\n",
    "\n",
    "Table 1:\n",
    "{table_1}\n",
    "\n",
    "Table 2:\n",
    "{table_2}\n",
    "\n",
    "Table 3:\n",
    "{table_3}\n",
    "\n",
    "Make sure you ONLY output an SQL query and no explanation.\n",
    "\"\"\"\n",
    ")\n",
    "make_sql_prompt_tmpl = PromptTemplate(make_sql_prompt_tmpl_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "594693a2-deb0-4c58-bc82-cb3d6428b1a4",
   "metadata": {
    "id": "594693a2-deb0-4c58-bc82-cb3d6428b1a4"
   },
   "source": [
    "我们使用 `format` 方法将用户查询和来自 Jina Reranker v2 的前三个表定义填充到模板字段中："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f620d53-f459-45f9-936f-1d16a97ae357",
   "metadata": {
    "id": "9f620d53-f459-45f9-936f-1d16a97ae357"
   },
   "outputs": [],
   "source": [
    "make_sql_prompt = make_sql_prompt_tmpl.format(query_str=user_query,\n",
    "                                              table_1=ranked_tables[0][1],\n",
    "                                              table_2=ranked_tables[1][1],\n",
    "                                              table_3=ranked_tables[2][1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dafb58d0-2bfb-412a-ae0b-6cb88a797045",
   "metadata": {
    "id": "dafb58d0-2bfb-412a-ae0b-6cb88a797045"
   },
   "source": [
    "你可以看到我们将传递给 Mistral Instruct 的实际文本："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf1cd783-e823-4292-8b60-42a10c2e54c0",
   "metadata": {
    "id": "bf1cd783-e823-4292-8b60-42a10c2e54c0"
   },
   "outputs": [],
   "source": [
    "print(make_sql_prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df1d4d6b-5e8f-45fd-94bc-1bfe1819ea7a",
   "metadata": {
    "id": "df1d4d6b-5e8f-45fd-94bc-1bfe1819ea7a"
   },
   "source": [
    "现在，让我们将提示发送给 Mistral Instruct 并获取其响应："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c099c16b-4295-45e7-8dc0-a1ec4e1691b4",
   "metadata": {
    "id": "c099c16b-4295-45e7-8dc0-a1ec4e1691b4"
   },
   "outputs": [],
   "source": [
    "response = mistral_llm.complete(make_sql_prompt)\n",
    "sql_query = str(response)\n",
    "print(sql_query)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca5291a8-6435-4832-b9e5-8549fae2249b",
   "metadata": {
    "id": "ca5291a8-6435-4832-b9e5-8549fae2249b"
   },
   "source": [
    "## 运行 SQL 查询\n",
    "使用内置的 Python SQLite 接口，针对数据库 `videogames.db` 运行上面的 SQL 查询："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08492c60-4c45-40a3-9b90-fd1b3abafea9",
   "metadata": {
    "id": "08492c60-4c45-40a3-9b90-fd1b3abafea9"
   },
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "\n",
    "con = sqlite3.connect(\"videogames.db\")\n",
    "cur = con.cursor()\n",
    "sql_response = cur.execute(sql_query).fetchall()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "LHqzMjSGrEKx",
   "metadata": {
    "id": "LHqzMjSGrEKx"
   },
   "source": [
    "有关 SQLite 接口的详细信息，请参阅 [Python3 文档](https://docs.python.org/3/library/sqlite3.html)。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d8ea35a-04b5-477f-84d2-22b78fdd450f",
   "metadata": {
    "id": "7d8ea35a-04b5-477f-84d2-22b78fdd450f"
   },
   "source": [
    "检查结果："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "207b4e03-2061-4868-9d10-245c90615724",
   "metadata": {
    "id": "207b4e03-2061-4868-9d10-245c90615724"
   },
   "outputs": [],
   "source": [
    "sql_response"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f50e18e-1771-4cca-b737-9362bb2c84a3",
   "metadata": {
    "id": "2f50e18e-1771-4cca-b737-9362bb2c84a3"
   },
   "source": [
    "你可以通过运行您自己的 SQL 查询来检查结果是否正确。该数据库中存储的销售数据是浮动点数，可能是以千或百万为单位的销售数量。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "286cae31-b645-48a3-8c1c-021f44f9f77d",
   "metadata": {
    "id": "286cae31-b645-48a3-8c1c-021f44f9f77d"
   },
   "source": [
    "## 获取自然语言回答\n",
    "\n",
    "现在，我们将用户的查询、SQL 查询和结果通过一个新的提示模板传递回 Mistral Instruct。\n",
    "\n",
    "首先，使用 LlamaIndex 创建新的提示模板，和之前一样："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48c9260a-62c0-4643-ba51-28432868c108",
   "metadata": {
    "id": "48c9260a-62c0-4643-ba51-28432868c108"
   },
   "outputs": [],
   "source": [
    "rag_prompt_tmpl_str = (\n",
    "    \"\"\"\n",
    "Use the information in the JSON table to answer the following user query.\n",
    "Do not explain anything, just answer concisely. Use natural language in your\n",
    "answer, not computer formatting.\n",
    "\n",
    "USER QUERY: {query_str}\n",
    "\n",
    "JSON table:\n",
    "{json_table}\n",
    "\n",
    "This table was generated by the following SQL query:\n",
    "{sql_query}\n",
    "\n",
    "Answer ONLY using the information in the table and the SQL query, and if the\n",
    "table does not provide the information to answer the question, answer\n",
    "\"No Information\".\n",
    "\"\"\"\n",
    ")\n",
    "rag_prompt_tmpl = PromptTemplate(rag_prompt_tmpl_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78576d8c-30a0-4ab3-99ce-24c50d18a7b8",
   "metadata": {
    "id": "78576d8c-30a0-4ab3-99ce-24c50d18a7b8"
   },
   "source": [
    "我们将把 SQL 输出转换为 JSON 格式，这是 Mistral Instruct v0.1 理解的格式。\n",
    "\n",
    "填充模板字段："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5631fa1e-dd8f-42db-82a7-f1e3a5544a40",
   "metadata": {
    "id": "5631fa1e-dd8f-42db-82a7-f1e3a5544a40"
   },
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "rag_prompt = rag_prompt_tmpl.format(query_str=\"Identify the top 10 platforms by total sales\",\n",
    "                                    json_table=json.dumps(sql_response),\n",
    "                                    sql_query=sql_query)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03c1092f-6715-459e-9402-1618729d419c",
   "metadata": {
    "id": "03c1092f-6715-459e-9402-1618729d419c"
   },
   "source": [
    "现在从 Mistral Instruct 请求自然语言回答："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70efddbb-8514-4985-9126-bbb9a55d0c63",
   "metadata": {
    "id": "70efddbb-8514-4985-9126-bbb9a55d0c63"
   },
   "outputs": [],
   "source": [
    "rag_response = mistral_llm.complete(rag_prompt)\n",
    "print(str(rag_response))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60d0fe8f-1f67-44fb-bd90-a8a786086e99",
   "metadata": {
    "id": "60d0fe8f-1f67-44fb-bd90-a8a786086e99"
   },
   "source": [
    "## 尝试自己动手\n",
    "\n",
    "让我们将所有步骤组织成一个函数，并加入异常处理："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bbc48a6-e456-4b1f-9949-720b18230c8b",
   "metadata": {
    "id": "1bbc48a6-e456-4b1f-9949-720b18230c8b"
   },
   "outputs": [],
   "source": [
    "def answer_sql(user_query: str) -> str:\n",
    "  try:\n",
    "    ranked_tables = rank_tables(user_query, table_declarations, top_n=3)\n",
    "  except Exception as e:\n",
    "    print(f\"Ranking failed.\\nUser query:\\n{user_query}\\n\\n\")\n",
    "    raise(e)\n",
    "\n",
    "  make_sql_prompt = make_sql_prompt_tmpl.format(query_str=user_query,\n",
    "                                                table_1=ranked_tables[0][1],\n",
    "                                                table_2=ranked_tables[1][1],\n",
    "                                                table_3=ranked_tables[2][1])\n",
    "\n",
    "  try:\n",
    "    response = mistral_llm.complete(make_sql_prompt)\n",
    "  except Exception as e:\n",
    "    print(f\"SQL query generation failed\\nPrompt:\\n{make_sql_prompt}\\n\\n\")\n",
    "    raise(e)\n",
    "\n",
    "  # Backslash removal is a necessary hack because sometimes Mistral puts them\n",
    "  # in its generated code.\n",
    "  sql_query = str(response).replace(\"\\\\\", \"\")\n",
    "\n",
    "  try:\n",
    "    sql_response = sqlite3.connect(\"videogames.db\").cursor().execute(sql_query).fetchall()\n",
    "  except Exception as e:\n",
    "    print(f\"SQL querying failed. Query:\\n{sql_query}\\n\\n\")\n",
    "    raise(e)\n",
    "\n",
    "  rag_prompt = rag_prompt_tmpl.format(query_str=user_query,\n",
    "                                      json_table=json.dumps(sql_response),\n",
    "                                      sql_query=sql_query)\n",
    "  try:\n",
    "    rag_response = mistral_llm.complete(rag_prompt)\n",
    "    return str(rag_response)\n",
    "  except Exception as e:\n",
    "    print(f\"Answer generation failed. Prompt:\\n{rag_prompt}\\n\\n\")\n",
    "    raise(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5af30fba-d380-42da-97b8-00e5ead978e8",
   "metadata": {
    "id": "5af30fba-d380-42da-97b8-00e5ead978e8"
   },
   "source": [
    "尝试:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ce3d50d-6f96-4647-a584-c9f70e400449",
   "metadata": {
    "id": "5ce3d50d-6f96-4647-a584-c9f70e400449"
   },
   "outputs": [],
   "source": [
    "print(answer_sql(\"Identify the top 10 platforms by total sales.\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fde4b5a-1b1f-4142-9c8b-30695149db6b",
   "metadata": {
    "id": "0fde4b5a-1b1f-4142-9c8b-30695149db6b"
   },
   "source": [
    "试一试其他的问题:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00d4071c-657f-495d-ba8f-112bb7286943",
   "metadata": {
    "id": "00d4071c-657f-495d-ba8f-112bb7286943"
   },
   "outputs": [],
   "source": [
    "print(answer_sql(\"Summarize sales by region.\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc6d927e-be00-4476-936f-34acdf7730fe",
   "metadata": {
    "id": "dc6d927e-be00-4476-936f-34acdf7730fe"
   },
   "outputs": [],
   "source": [
    "print(answer_sql(\"List the publisher with the largest number of published games.\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4522592-4b82-440c-afa4-7379dedc19d8",
   "metadata": {
    "id": "e4522592-4b82-440c-afa4-7379dedc19d8"
   },
   "outputs": [],
   "source": [
    "print(answer_sql(\"Display the year with most games released.\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95550428-4a0b-4324-bb0c-719ba0026b77",
   "metadata": {
    "id": "95550428-4a0b-4324-bb0c-719ba0026b77"
   },
   "outputs": [],
   "source": [
    "print(answer_sql(\"What is the most popular game genre on the Wii platform?\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb3de1f5-7e62-4370-a6eb-311485031fe1",
   "metadata": {
    "id": "eb3de1f5-7e62-4370-a6eb-311485031fe1"
   },
   "outputs": [],
   "source": [
    "print(answer_sql(\"What is the most popular game genre of 2012?\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cdffc3f-5636-4383-bccb-b7dfa1f2745a",
   "metadata": {
    "id": "9cdffc3f-5636-4383-bccb-b7dfa1f2745a"
   },
   "source": [
    "试一试你自己的问题:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "553c721f-f80c-4e2c-a77f-fb2b59d1b555",
   "metadata": {
    "id": "553c721f-f80c-4e2c-a77f-fb2b59d1b555"
   },
   "outputs": [],
   "source": [
    "print(answer_sql(\"<INSERT QUESTION OR INSTRUCTION HERE>\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "-srB-MOMk1b-",
   "metadata": {
    "id": "-srB-MOMk1b-"
   },
   "source": [
    "## 复习与总结\n",
    "\n",
    "我们向你展示了如何构建一个非常基础的 RAG（检索增强生成）系统，用于自然语言问答，并将 SQL 数据库作为信息来源。在这个实现中，我们使用相同的大型语言模型（Mistral Instruct v0.1）来生成 SQL 查询和构造自然语言回答。\n",
    "\n",
    "这里的数据库是一个非常小的示例，扩展到更大规模可能需要比仅仅对表定义进行排序更复杂的方法。你可能需要使用一个双阶段的过程，其中嵌入模型和向量存储首先检索更多的结果，但重排序模型会将结果修剪到你能够放入生成语言模型提示中的数量。\n",
    "\n",
    "本 Notebook 假设没有任何请求需要超过三个表来满足，显然，在实际应用中，这种假设并不总是成立。Mistral 7B Instruct v0.1 并不保证生成正确（甚至是可执行的）SQL 输出。在生产环境中，类似的实现需要更深入的错误处理。\n",
    "\n",
    "更复杂的错误处理、更长的输入上下文窗口以及专门用于 SQL 任务的生成模型，可能在实际应用中带来显著的改进。\n",
    "\n",
    "尽管如此，你可以看到 RAG 概念如何扩展到结构化数据库，极大地扩展了其应用范围。"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
